10735
17844
Jaký je nejjednodušší způsob, jak transformovat tenzor tvaru (batch_size, výška, šířka) naplněný n hodnotami na tenzor tvaru (batch_size, n, výška, šířka)?
Níže jsem vytvořil řešení, ale vypadá to, že existuje jednodušší a rychlejší způsob, jak to udělat
def batch_tensor_to_onehot (tnsr, třídy):
tnsr = tnsr.unsqueeze (1)
res = []
pro cls v rozsahu (třídách):
res.append ((tnsr == cls) .long ())
vrátit torch.cat (res, dim = 1) 
Můžete použít torch.nn.functional.one_hot.
Pro váš případ:
a = torch.nn.functional.one_hot (tnsr, num_classes = classes)
out = a.permute (0, 3, 1, 2)
|
Můžete také použít Tensor.scatter_, který se vyhýbá .permute, ale je pravděpodobně obtížnější pochopit než přímá metoda navržená @Alpha.
def batch_tensor_to_onehot (tnsr, třídy):
result = torch.zeros (tnsr.shape [0], classes, * tnsr.shape [1:], dtype = torch.long, device = tnsr.device)
result.scatter_ (1, tnsr.unsqueeze (1), 1)
vrátit výsledek
Výsledky srovnávání
Byl jsem zvědavý a rozhodl jsem se porovnat tři přístupy. Zjistil jsem, že mezi navrhovanými metodami, pokud jde o velikost dávky, šířku nebo výšku, se nezdá významný relativní rozdíl. Rozlišujícím faktorem byl především počet tříd. Samozřejmě jako u každého standardního počtu ujetých kilometrů se může lišit.
Benchmarky byly shromážděny pomocí náhodných indexů a pomocí velikosti šarže, výšky, šířky = 100. Každý experiment byl opakován 20krát s uváděným průměrem. Experiment num_classes = 100 se spustí jednou před profilováním pro zahřátí.
Výsledky CPU ukazují, že původní metoda byla pravděpodobně nejlepší pro num_classes menší než asi 30, zatímco pro GPU se přístup scatter_ zdá být nejrychlejší.
Testy prováděné na Ubuntu 18.04, NVIDIA 2060 Super, i7-9700K
Kód používaný pro srovnávání je uveden níže:
importovat pochodeň
z tqdm importovat tqdm
čas importu
importovat matplotlib.pyplot jako plt
def batch_tensor_to_onehot_slavka (tnsr, třídy):
tnsr = tnsr.unsqueeze (1)
res = []
pro cls v rozsahu (třídách):
res.append ((tnsr == cls) .long ())
vrátit torch.cat (res, dim = 1)
def batch_tensor_to_onehot_alpha (tnsr, třídy):
result = torch.nn.functional.one_hot (tnsr, num_classes = classes)
návrat result.permute (0, 3, 1, 2)
def batch_tensor_to_onehot_jodag (tnsr, třídy):
result = torch.zeros (tnsr.shape [0], classes, * tnsr.shape [1:], dtype = torch.long, device = tnsr.device)
result.scatter_ (1, tnsr.unsqueeze (1), 1)
vrátit výsledek
def main ():
num_classes = [2, 10, 25, 50, 100]
výška = 100
šířka = 100
bs = [100] * 20
pro d v ['cpu', 'cuda']:
times_slavka = []
times_alpha = []
times_jodag = []
warmup = Pravda
pro c v tqdm ([num_classes [-1]] + num_classes, ncols = 0):
tslavka = 0
talpha = 0
tjodag = 0
pro b v bs:
tnsr = torch.randint (c, (b, výška, šířka)). do (zařízení = d)
t0 = time.time ()
y = batch_tensor_to_onehot_slavka (tnsr, c)
torch.cuda.synchronize ()
tslavka + = time.time () - t0
pokud ne, zahřátí:
times_slavka.append (tslavka / len (bs))
pro b v bs:
tnsr = torch.randint (c, (b, výška, šířka)). do (zařízení = d)
t0 = time.time ()
y = batch_tensor_to_onehot_alpha (tnsr, c)
torch.cuda.synchronize ()
talpha + = time.time () - t0
pokud ne, zahřátí:
times_alpha.append (talpha / len (bs))
pro b v bs:
tnsr = torch.randint (c, (b, výška, šířka)). do (zařízení = d)
t0 = time.time ()
y = batch_tensor_to_onehot_jodag (tnsr, c)
torch.cuda.synchronize ()
tjodag + = time.time () - t0
pokud ne, zahřátí:
times_jodag.append (tjodag / len (bs))
warmup = False
fig = plt.figure ()
ax = fig.subplots ()
ax.plot (num_classes, times_slavka, label = 'Slavka-cat')
ax.plot (num_classes, times_alpha, label = 'Alpha-one_hot')
ax.plot (num_classes, times_jodag, label = 'jodag-scatter_')
ax.set_xlabel ('num_classes')
ax.set_ylabel ('time (s)')
ax.set_title (f '{d} benchmark')
ax.legend ()
plt.savefig (f '{d} .png')
plt.show ()
pokud __name__ == "__main__":
hlavní()
|
Tvoje odpověď
StackExchange.ifUsing ("editor", function () {
StackExchange.using ("externalEditor", function () {
StackExchange.using ("snippets", function () {
StackExchange.snippets.init ();
});
});
}, „code-snippets“);
StackExchange.ready (funkce () {
var channelOptions = {
tagy: "" .split (""),
id: "1"
};
initTagRenderer ("". split (""), "" .split (""), channelOptions);
StackExchange.using ("externalEditor", function () {
// Je nutné po úryvcích vypálit editor, pokud jsou úryvky povoleny
if (StackExchange.settings.snippets.snippetsEnabled) {
StackExchange.using ("snippets", function () {
createEditor ();
});
}
else {
createEditor ();
}
});
funkce createEditor () {
StackExchange.prepareEditor ({
useStacksEditor: false,
heartbeatType: 'answer',
autoActivateHeartbeat: false,
convertImagesToLinks: true,
noModals: true,
showLowRepImageUploadWarning: true,
reputationToPostImages: 10,
bindNavPrevention: true,
postfix: "",
imageUploader: {
brandingHtml: "Powered by \ u003ca href = \" https: //imgur.com/ \ "\ u003e \ u003csvg class = \" svg-icon \ "width = \" 50 \ "height = \" 18 \ "viewBox = \ "0 0 50 18 \" fill = \ "none \" xmlns = \ "http: //www.w3.org/2000/svg \" \ u003e \ u003cpath d = \ "M46.1709 9.17788C46.1709 8.26454 46,2665 7,94324 47,1084 7,58816C47.4091 7,46349 47,7169 7,36433 48,0099 7,26993C48,9099 6,97977 49,672 6,73443 49,672 5,93063C49,672 5,22043 48,9832 4,61182 48,1414 4,61182C47,4335 4,61182 46,7256 4,91628 4,91650 4,91650 4,416 164 4,45 43,1481 6,59048V11,9512C43,1481 13,2535 43,6264 13,8962 44,6595 13,8962C45,6924 13,8962 46,1709 13,253546.1709 11,9512V9.17788Z \ "/ \ u003e \ u003cpath d = \" M32.492 10.1419C32.492 12.6954 34.1182 14.0484 37.0451 14.0484C39.9723 14.0484 41.5985 12.6954 41.5985 10.1419V6.59049C41.5324 422 422 422 422 422 422 422 422 422 38,5948 5,28821 38,5948 6,59049V9,60062C38,5948 10,8521 38,2696 11,5455 37,0451 11,5455C35,8209 11,5455 35,4954 10,8521 35,4954 9,60062V6,59049C35,4954 5,288821 35,0173 4,66232 34,0034 4,6232C32,970 32,32432 fill-rule = \ "evenodd \" clip-rule = \ "evenodd \" d = \ "M25.6622 17.6335C27.8049 17.6335 29.3739 16.9402 30.2537 15.6379C30.8468 14.7755 30.9615 13.5579 30.9615 11.9512V6.59049C30.9615 5.28821 30.4833 4.66231 29,4502 4,66231C28,9913 4,66231 28,4555 4,94978 28,1109 5,50789C27,499 4,86533 26,7335 4,56087 25,7005 4,56087C23.1369 4,56087 21,0134 6,57349 21,0134 9,279,32 21,126 28,13,13 13,13 16 13,13 13,13 13 13,13 13 13,13 13 13,13 13 13,13 C28. 1256 12,884 28,1301 12,9342 28,1301 12,983 C28.1301 14,4373 27,2502 15,2321 25,777 15,2321C24,8349 15,2321 24,1352 14,9821 23,5661 14,7787C23,176 14,6393 22,8472 14,5218 22,5437 14,5218C21,7977 14,52 21,242 212 212 212 212 212 212 212 212 212 212 212 212 212 212 212 212 212 212 212 212 212 212 212 212 212 C24.1317 7,94324 24,9928 7,09766 26,1024 7,09766C27.2119 7,09766 28,0918 7,94324 28,0918 9,27932 C28.0918 10,6321 27,2311 11,5116 26,1024 11,5116C24,9737 11,5116 24,1317 10,6491 24,1317 9,2932 Z \ "/ \ u003" 8045 13,2535 17,2637 13,8962 18,2965 13,8962C19,3298 13,8962 19,8079 13,2535 19,8079 11,9512V8,12928C19,8079 5,82936 18,4879 4,62866 16,4027 4,62866C15,1594 4,62866 14,279 4,983 75 13,358 10,457 10,285 10,456 10,458 10,458 10,458 10,457 10,455 10,455 10,455 10,457 10,455 10,455 58314 4,9328 7,10506 4,6632 6,52203 4,66232 C5,47873 4,66232 5 00066 5,2888 5 00066 6,59049 V11,9512C5 00066 13,2535 5,47873 13,8962 6,51203 13,8962C7,54479 13,8962 8,0232 13 0,2535 8,0232 11,9512V8,90741C8,0232 7,58817 8,44431 6,91179 9,53458 6,91179C10,5104 6,91179 10,893 7,58817 10,893 8,94108V11,9512C10,893 13,2535 11,3711 13,8962 12,4044 13,8962C13,4375 13,8962 13,97 13,137,13 13,15 13,13 C16.4027 6,91179 16,8045 7,58817 16,8045 8,94108V11.9512Z \ "/ \ u003e \ u003cpath d = \" M3.31675 6,59049C3.31675 5,2821 2,83866 4,66232 1,82471 4,66232C0,791758 4,66232 0,313354 5,28813 0,13354 6,13354 6,13354 6,13354 6,13354 6,13354 1,82471 13,8962C2.85798 13,8962 3,31675 13,2535 3,31675 11,9512V6,59049Z \ "/ \ u003e \ u003cpath d = \" M1.87209 0,400291C0,843612 0,400291 0 1,1159 0 1,98861C0 2,87869 0,822846 3,57677 1,876,987 2,876 1,876 1,876 2,876 1,876 1,876 1,876 1,876 1,876 1,876 1,876 1,87 C3.7234 1.1159 2.90056 0.400291 1.87209 0.400291Z \ "fill = \" # 1BB76E \ "/ \ u003e \ u003c / svg \ u003e \ u003c / a \ u003e",
contentPolicyHtml: "Uživatelské příspěvky jsou licencovány pod \ u003ca href = \" https: //stackoverflow.com/help/licensing \ "\ u003ecc by-sa \ u003c / a \ u003e \ u003ca href = \" https://stackoverflow.com / legal / content-policy \ "\ u003e (obsahová politika) \ u003c / a \ u003e",
allowUrls: true
},
onDemand: true,
discardSelector: ".discard-answer"
, okamžitěShowMarkdownHelp: true, enableTables: true, enableSnippets: true
});
}
});
Děkujeme, že jste přispěli odpovědí na Stack Overflow!
Nezapomeňte na otázku odpovědět. Uveďte podrobnosti a sdílejte svůj výzkum!
Ale vyhnout se ...
Žádáme o pomoc, vysvětlení nebo reagujeme na jiné odpovědi.
Vytváření prohlášení na základě názoru; podpořte je referencemi nebo osobními zkušenostmi.
Chcete-li se dozvědět více, přečtěte si naše tipy na psaní skvělých odpovědí.
Koncept uložen
Koncept zahozen
Zaregistrujte se nebo se přihlaste
StackExchange.ready (funkce () {
StackExchange.helpers.onClickDraftSave ('# login-link');
});
Zaregistrujte se pomocí Google
Zaregistrujte se pomocí Facebooku
Zaregistrujte se pomocí e-mailu a hesla
Předložit
Zveřejněte jako host
název
E-mailem
Povinné, ale nikdy zobrazené
StackExchange.ready (
funkce () {
StackExchange.openid.initPostLogin ('. New-post-login', 'https% 3a% 2f% 2fstackoverflow.com% 2fquestions% 2f62245173% 2fpytorch-transform-tensor-to-one-hot% 23new-answer', 'question_page' );
}
);
Zveřejněte jako host
název
E-mailem
Povinné, ale nikdy zobrazené
Zveřejněte svou odpověď
Vyřadit
Kliknutím na „Odeslat odpověď“ vyjadřujete souhlas s našimi podmínkami služby, zásadami ochrany osobních údajů a zásadami používání souborů cookie
Toto není odpověď, kterou hledáte? Přečtěte si další otázky týkající se značek python pytorch tensor one-hot-encoding